# Lint as: python3
"""AMDSHARK Downloader"""
# Requirements : Put amdshark_tank in AMDSHARK directory
#   /AMDSHARK
#     /gen_amdshark_tank
#       /tflite
#         /albert_lite_base
#         /...model_name...
#       /tf
#       /pytorch
#
#
#

import numpy as np
import os
from tqdm.std import tqdm
import sys
from pathlib import Path
from amdshark.parser import amdshark_args
from google.cloud import storage


def _download_blob_to_path(storage_client, blob, destination_filename):
    """Stream ``blob`` into ``destination_filename`` with a tqdm progress bar."""
    with open(destination_filename, "wb") as f:
        # wrapattr hooks the file's write() so tqdm can track bytes written.
        with tqdm.wrapattr(f, "write", total=blob.size) as file_obj:
            storage_client.download_blob_to_file(blob, file_obj)


def download_public_file(
    full_gs_url, destination_folder_name, single_file=False
):
    """Download public blob(s) from a GCS bucket using an anonymous client.

    Args:
        full_gs_url: URL of the form ``gs://<bucket>/<object path>``. In
            directory mode the object path is used as a listing prefix; in
            single-file mode it names exactly one object.
        destination_folder_name: In directory mode, the local folder that
            receives the blobs (every blob is written flat into this folder
            under its base name). In single-file mode, the full local path
            of the file to write.
        single_file: When True, download only the object named by
            ``full_gs_url``.

    If no blob matches, nothing is downloaded and no error is raised.
    """
    url_parts = full_gs_url.split("/")
    bucket_name = url_parts[2]
    object_path = "/".join(url_parts[3:])

    storage_client = storage.Client.create_anonymous_client()
    bucket = storage_client.bucket(bucket_name)

    dest_filename = None
    if single_file:
        destination_folder_name, dest_filename = os.path.split(
            destination_folder_name
        )
        if not dest_filename:
            # Destination was given as a folder (trailing separator); keep
            # the remote file name.
            dest_filename = url_parts[-1]

    # makedirs handles nested destinations and an already-existing folder;
    # an empty folder component means "current directory".
    if destination_folder_name:
        os.makedirs(destination_folder_name, exist_ok=True)

    if single_file:
        destination_filename = os.path.join(
            destination_folder_name, dest_filename
        )
        for blob in bucket.list_blobs(prefix=object_path):
            # Match the exact object name. Comparing only base names would
            # also accept same-named files under sibling prefixes (for
            # example "<model>_BS2/hash.npy" when "<model>/hash.npy" was
            # requested).
            if blob.name != object_path:
                continue
            _download_blob_to_path(storage_client, blob, destination_filename)
            break
        return

    # NOTE(review): the listing prefix is a plain string prefix, so blobs
    # under sibling prefixes that merely start with the same text are
    # included as well -- confirm whether callers rely on this.
    for blob in bucket.list_blobs(prefix=object_path):
        blob_name = blob.name.split("/")[-1]
        destination_filename = os.path.join(
            destination_folder_name, blob_name
        )
        # Skip entries that resolve to a directory (e.g. names ending in
        # "/", whose base name is empty).
        if os.path.isdir(destination_filename):
            continue
        _download_blob_to_path(storage_client, blob, destination_filename)


# Lookup table from dtype name strings to the corresponding NumPy scalar types.
input_type_to_np_dtype = dict(
    float32=np.float32,
    float64=np.float64,
    bool=np.bool_,
    int32=np.int32,
    int64=np.int64,
    uint8=np.uint8,
    int8=np.int8,
)

# Save the model in the home local so it needn't be fetched every time in the CI.
#
# WORKDIR is chosen in this order of precedence:
#   1. the --local_tank_cache flag (amdshark_args.local_tank_cache), if set;
#   2. a gen_amdshark_tank directory next to this package, if it exists;
#   3. ~/.local/amdshark_tank/ otherwise.
home = str(Path.home())
alt_path = os.path.join(os.path.dirname(__file__), "../gen_amdshark_tank/")
custom_path = amdshark_args.local_tank_cache

if custom_path is not None:
    # makedirs (not mkdir) so a nested cache path whose parents do not exist
    # yet is created instead of raising FileNotFoundError.
    os.makedirs(custom_path, exist_ok=True)

    WORKDIR = custom_path

    print(f"Using {WORKDIR} as local amdshark_tank cache directory.")

elif os.path.exists(alt_path):
    WORKDIR = alt_path
    print(
        f"Using {WORKDIR} as amdshark_tank directory. Delete this directory if you aren't working from locally generated amdshark_tank."
    )
else:
    WORKDIR = os.path.join(home, ".local/amdshark_tank/")
    print(
        f"amdshark_tank local cache is located at {WORKDIR} . You may change this by setting the --local_tank_cache= flag"
    )
os.makedirs(WORKDIR, exist_ok=True)


# Checks whether the model directory and all expected artifact files exist.
def check_dir_exists(model_name, frontend="torch", dynamic=""):
    """Return True if the cached artifacts for a model are all present.

    Args:
        model_name: Name of the model's directory under WORKDIR, e.g.
            ``<model>_<frontend>``, ``<model>_<frontend>_BS<n>``, or the
            bare model name for clip/unet/vae models.
        frontend: Frontend name used in the MLIR file name.
        dynamic: Extra infix for the MLIR file name ("" or e.g. "_dynamic").

    Returns:
        True when the directory contains the MLIR file
        ``<base>{dynamic}_{frontend}.mlir`` plus ``function_name.npy``,
        ``inputs.npz``, ``golden_out.npz`` and ``hash.npy``; False otherwise.
    """
    model_dir = os.path.join(WORKDIR, model_name)

    # A trailing "_BS<n>" only appears in the directory name, never in the
    # MLIR file name, so drop it before deriving the base model name.
    head, sep, batch = model_name.rpartition("_BS")
    has_batch_suffix = bool(sep) and batch.isdigit()
    if has_batch_suffix:
        model_name = head

    # Remove the frontend keyword from the end only for non-SD models; SD
    # (clip/unet/vae) directories carry it only together with "_BS<n>".
    is_sd_model = any(tag in model_name for tag in ("clip", "unet", "vae"))
    if has_batch_suffix or not is_sd_model:
        legacy_suffixes = {
            "tf": "_tf",
            "tensorflow": "_tf",
            "tflite": "_tflite",
            "torch": "_torch",
            "pytorch": "_torch",
        }
        # Strip the suffix that is actually present instead of slicing off a
        # fixed number of characters: prefer the exact "_<frontend>" form,
        # then fall back to the short alias (e.g. "_tf" for "tensorflow").
        for suffix in (f"_{frontend}", legacy_suffixes.get(frontend)):
            if suffix and model_name.endswith(suffix):
                model_name = model_name[: -len(suffix)]
                break

    model_mlir_file_name = f"{model_name}{dynamic}_{frontend}.mlir"
    required_files = (
        model_mlir_file_name,
        "function_name.npy",
        "inputs.npz",
        "golden_out.npz",
        "hash.npy",
    )

    if os.path.isdir(model_dir) and all(
        os.path.isfile(os.path.join(model_dir, file_name))
        for file_name in required_files
    ):
        print(f"""Model artifacts for {model_name} found at {WORKDIR}...""")
        return True
    return False


def _internet_connected():
    """Return True if an HTTP GET to http://1.1.1.1 succeeds, else False.

    This is a best-effort probe: any error raised by the request (including
    a timeout) is treated as "not connected".
    """
    import requests as req

    try:
        # Bound the wait so a stalled connection cannot hang the caller
        # indefinitely.
        req.get("http://1.1.1.1", timeout=5)
    except Exception:
        # Deliberately broad, but unlike a bare ``except`` this lets
        # KeyboardInterrupt / SystemExit propagate.
        return False
    return True


def get_git_revision_short_hash() -> str:
    """Return the amdshark_tank bucket prefix to look for.

    Despite its name, this function does not invoke git. It returns
    ``amdshark_args.amdshark_prefix`` when that is set; otherwise it reads
    the ``"version"`` entry of ``tank_version.json`` located one directory
    above this file.
    """
    if amdshark_args.amdshark_prefix is not None:
        prefix_kw = amdshark_args.amdshark_prefix
    else:
        import json

        dir_path = os.path.dirname(os.path.realpath(__file__))
        src = os.path.join(dir_path, "..", "tank_version.json")
        # Decode explicitly as UTF-8 rather than the locale default.
        with open(src, "r", encoding="utf-8") as f:
            prefix_kw = json.load(f)["version"]
    print(f"Checking for updates from gs://amdshark_tank/{prefix_kw}")
    return prefix_kw


def get_amdsharktank_prefix():
    """Pick the top-level prefix of the amdshark_tank bucket to download from.

    Returns "none" when there is no internet connection. Otherwise returns
    the first top-level name in the bucket that contains the desired prefix,
    or "nightly" when no such name is found.
    """
    if not _internet_connected():
        print(
            "No internet connection. Using the model already present in the tank."
        )
        return "none"

    desired_prefix = get_git_revision_short_hash()
    anonymous_client = storage.Client.create_anonymous_client()
    tank_bucket = anonymous_client.bucket("amdshark_tank")

    # First path component of every blob listed under the desired prefix.
    top_level_names = (
        blob.name.split("/")[0]
        for blob in tank_bucket.list_blobs(prefix=f"{desired_prefix}")
    )
    tank_prefix = next(
        (name for name in top_level_names if desired_prefix in name), ""
    )

    if tank_prefix == "":
        print(
            f"amdshark_tank bucket not found matching ({desired_prefix}). Defaulting to nightly."
        )
        tank_prefix = "nightly"
    return tank_prefix


# Downloads the model artifacts from gs://amdshark_tank (or from tank_url).
def download_model(
    model_name,
    dynamic=False,
    tank_url=None,
    frontend=None,
    tuned=None,
    import_args=None,
):
    """Fetch (or reuse cached) artifacts for a model and load them.

    Args:
        model_name: Model name; any "/" is replaced with "_".
        dynamic: When truthy, the "_dynamic" variant of the MLIR is used.
        tank_url: Base ``gs://`` URL. Defaults to
            ``gs://amdshark_tank/<prefix>``, where the prefix comes from
            ``get_amdsharktank_prefix()``.
        frontend: Frontend name; part of the directory and MLIR file names.
        tuned: Optional tag appended to the MLIR file name as ``_<tuned>``.
        import_args: Dict of import options. Only ``"batch_size"`` is read
            here; the dict is forwarded to ``gen_amdshark_files`` if local
            generation is needed. Defaults to ``{"batch_size": 1}``.

    Returns:
        Tuple ``(mlir_filename, function_name, inputs_tuple,
        golden_out_tuple)``.

    Raises:
        AssertionError: If the MLIR file is still missing after attempting
            local generation.
    """
    # Build a fresh default per call: a dict literal in the signature would
    # be shared between calls and is handed to gen_amdshark_files below.
    if import_args is None:
        import_args = {"batch_size": 1}
    # .get() so dicts without a "batch_size" entry behave like batch size 1
    # instead of raising KeyError.
    batch_size = import_args.get("batch_size")

    model_name = model_name.replace("/", "_")
    dyn_str = "_dynamic" if dynamic else ""
    os.makedirs(WORKDIR, exist_ok=True)
    amdshark_args.amdshark_prefix = get_amdsharktank_prefix()
    if batch_size and batch_size != 1:
        model_dir_name = model_name + "_" + frontend + "_BS" + str(batch_size)
    elif any(model in model_name for model in ["clip", "unet", "vae"]):
        # TODO(Ean Garvey): rework extended naming such that device is only included in model_name after .vmfb compilation.
        model_dir_name = model_name
    else:
        model_dir_name = model_name + "_" + frontend
    model_dir = os.path.join(WORKDIR, model_dir_name)

    if not tank_url:
        tank_url = "gs://amdshark_tank/" + amdshark_args.amdshark_prefix

    full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
    if not check_dir_exists(
        model_dir_name, frontend=frontend, dynamic=dyn_str
    ):
        print(
            f"Downloading artifacts for model {model_name} from: {full_gs_url}"
        )
        download_public_file(full_gs_url, model_dir)

    elif amdshark_args.force_update_tank == True:
        print(
            f"Force-updating artifacts for model {model_name} from: {full_gs_url}"
        )
        download_public_file(full_gs_url, model_dir)
    else:
        if not _internet_connected():
            print(
                "No internet connection. Using the model already present in the tank."
            )
        else:
            # Compare the cached hash with the upstream one to decide whether
            # the cached artifacts are stale.
            local_hash = str(np.load(os.path.join(model_dir, "hash.npy")))
            upstream_hash_path = os.path.join(model_dir, "upstream_hash.npy")
            download_public_file(
                full_gs_url + "/hash.npy",
                upstream_hash_path,
                single_file=True,
            )
            try:
                upstream_hash = str(np.load(upstream_hash_path))
            except FileNotFoundError:
                print(f"Model artifact hash not found at {model_dir}.")
                upstream_hash = None
            if local_hash != upstream_hash and amdshark_args.update_tank == True:
                print(f"Updating artifacts for model {model_name}...")
                download_public_file(full_gs_url, model_dir)

            elif local_hash != upstream_hash:
                print(
                    "Hash does not match upstream in gs://amdshark_tank/. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
                )
            else:
                print(
                    "Local and upstream hashes match. Using cached model artifacts."
                )

    tuned_str = "" if tuned is None else "_" + tuned
    suffix = f"{dyn_str}_{frontend}{tuned_str}.mlir"
    mlir_filename = os.path.join(model_dir, model_name + suffix)
    print(
        f"Verifying that model artifacts were downloaded successfully to {mlir_filename}..."
    )
    if not os.path.exists(mlir_filename):
        # Imported lazily: only needed when the download did not provide the
        # MLIR file.
        from tank.generate_amdsharktank import gen_amdshark_files

        print(
            "The model data was not found. Trying to generate artifacts locally."
        )
        gen_amdshark_files(model_name, frontend, WORKDIR, import_args)

    assert os.path.exists(mlir_filename), f"MLIR not found at {mlir_filename}"
    function_name = str(np.load(os.path.join(model_dir, "function_name.npy")))
    inputs = np.load(os.path.join(model_dir, "inputs.npz"))
    golden_out = np.load(os.path.join(model_dir, "golden_out.npz"))

    inputs_tuple = tuple(inputs[key] for key in inputs)
    golden_out_tuple = tuple(golden_out[key] for key in golden_out)
    return mlir_filename, function_name, inputs_tuple, golden_out_tuple
